In [ ]:
epochs = 15

Tutorial 3 - Folded Split Neural Network

Recap: The previous tutorial looked at building a SplitNN. The NN was split into three segments on three separate hosts, where one host had the data and another had the labels. However, what if someone has the data and the labels in the same place?

Description: Here we fold a multilayer SplitNN in on itself in order to accommodate the data and labels being in the same place. We demonstrate the SplitNN class with a 3-segment distribution. This time,

  • Alice
    • Has Model Segment 1
    • Has Model Segment 3
    • Has the handwritten images
    • Has the image labels
  • Bob
    • Has Model Segment 2

Again, we use the exact same model as in the previous tutorial and see the same accuracy. Neither Alice nor Bob has the full model, and Bob can't see Alice's data.
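To make the folded layout concrete before we build it, here is a small illustrative sketch of the segment-to-worker mapping described above (the real workers and model segments are created further down in this notebook):

In [ ]:
# Illustrative only: the folded placement described above, as a plain mapping.
# Segments 1 and 3 sit with Alice (who also holds the images and labels),
# while segment 2 sits with Bob.
folded_layout = {
    "segment_1": "alice",
    "segment_2": "bob",
    "segment_3": "alice",
}
print(folded_layout)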

Author:


In [ ]:
class SplitNN:
    def __init__(self, models, optimizers):
        self.models = models
        self.optimizers = optimizers

    def forward(self, x):
        a = []
        remote_a = []

        # First segment: compute activations where the data lives
        a.append(self.models[0](x))
        if a[-1].location == self.models[1].location:
            remote_a.append(a[-1].detach().requires_grad_())
        else:
            remote_a.append(a[-1].detach().move(self.models[1].location).requires_grad_())

        # Middle segments: detach at each cut and move the activations
        # to wherever the next segment lives
        i = 1
        while i < (len(self.models) - 1):

            a.append(self.models[i](remote_a[-1]))
            if a[-1].location == self.models[i + 1].location:
                remote_a.append(a[-1].detach().requires_grad_())
            else:
                remote_a.append(a[-1].detach().move(self.models[i + 1].location).requires_grad_())

            i += 1

        # Final segment produces the prediction
        a.append(self.models[i](remote_a[-1]))
        self.a = a
        self.remote_a = remote_a

        return a[-1]

    def backward(self):
        a = self.a
        remote_a = self.remote_a

        # Relay the gradient at each cut back to the earlier segment,
        # moving it between workers when necessary
        i = len(self.models) - 2
        while i > -1:
            if remote_a[i].location == a[i].location:
                grad_a = remote_a[i].grad.copy()
            else:
                grad_a = remote_a[i].grad.copy().move(a[i].location)
            a[i].backward(grad_a)
            i -= 1

    def zero_grads(self):
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        for opt in self.optimizers:
            opt.step()

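The key trick in forward and backward above is cutting the autograd graph at each segment boundary and stitching it back together by hand. Below is a minimal local sketch of that detach-and-relay pattern in plain PyTorch, without syft or remote workers (the layer sizes are arbitrary, chosen just for illustration). In the SplitNN class, the same two steps happen across workers: the detached activation is moved to the next segment's location, and the gradient at the cut is moved back before being fed into backward.

In [ ]:
# A minimal, local (non-syft) sketch of the cut-and-relay trick used above:
# detach the activation at the cut, re-enable grad tracking on the copy,
# and later hand the gradient at the cut back to the earlier segment.
import torch
from torch import nn

seg1 = nn.Linear(4, 3)
seg2 = nn.Linear(3, 2)

x = torch.randn(1, 4)
a1 = seg1(x)                               # activation at the cut
remote_a1 = a1.detach().requires_grad_()   # what would be sent to the next worker
out = seg2(remote_a1)
loss = out.sum()

loss.backward()                            # fills remote_a1.grad, stops at the cut
a1.backward(remote_a1.grad)                # relay the gradient into segment 1
print(seg1.weight.grad.shape, seg2.weight.grad.shape)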
In [ ]:
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import syft as sy
hook = sy.TorchHook(torch)

In [ ]:
# Data preprocessing
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                              ])
trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

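As a quick optional check (not part of the original tutorial), we can peek at one batch to see why input_size is set to 784 in the next cell: each MNIST image is 1 x 28 x 28, which flattens to 784 values.

In [ ]:
# Optional sanity check on the batch shapes produced by the loader above
images, labels = next(iter(trainloader))
print(images.shape)                              # torch.Size([64, 1, 28, 28])
print(images.view(images.shape[0], -1).shape)    # torch.Size([64, 784])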
In [ ]:
torch.manual_seed(0)

# Define our model segments

input_size = 784
hidden_sizes = [128, 640]
output_size = 10

models = [
    nn.Sequential(
                nn.Linear(input_size, hidden_sizes[0]),
                nn.ReLU(),
    ),
    nn.Sequential(
                nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                nn.ReLU(),
    ),
    nn.Sequential(
                nn.Linear(hidden_sizes[1], output_size),
                nn.LogSoftmax(dim=1)
    )
]

# Create an optimiser for each model segment
optimizers = [
    optim.SGD(model.parameters(), lr=0.03,)
    for model in models
]

# Create the virtual workers
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")

# Send each model segment to its location: Alice, Bob, then back to Alice
model_locations = [alice, bob, alice]
for model, location in zip(models, model_locations):
    model.send(location)

# Instantiate a SplitNN class with our distributed segments and their respective optimizers
splitNN = SplitNN(models, optimizers)

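As another optional sanity check (not in the original tutorial), we can confirm the folded placement by printing where each segment now lives; .location is the same attribute the SplitNN class and training loop rely on.

In [ ]:
# Expect: segment 1 -> alice, segment 2 -> bob, segment 3 -> alice
for idx, model in enumerate(models):
    print("segment", idx + 1, "->", model.location.id)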
In [ ]:
def train(x, target, splitNN):
    
    #1) Zero our grads
    splitNN.zero_grads()
    
    #2) Make a prediction
    pred = splitNN.forward(x)
    
    #3) Figure out how much we missed by
    criterion = nn.NLLLoss()
    loss = criterion(pred, target)
    
    #4) Backprop the loss on the end layer
    loss.backward()
    
    #5) Feed gradients backward through the network
    splitNN.backward()
    
    #6) Change the weights
    splitNN.step()
    
    return loss

In [ ]:
for i in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        # Send the data to the first segment's worker and the labels to the last segment's worker
        images = images.send(models[0].location)
        images = images.view(images.shape[0], -1)
        labels = labels.send(models[-1].location)
        loss = train(images, labels, splitNN)
        running_loss += loss.get()

    print("Epoch {} - Training loss: {}".format(i, running_loss / len(trainloader)))